import numpy as np
import warnings
warnings.filterwarnings('ignore')
import glob
from PIL import Image
import os
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import torch
from torch import nn, optim
import torch.nn.functional as F
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
import torchvision.datasets as dset
import torchvision.utils as vutils
from torchvision.utils import make_grid
import pandas as pd
from IPython.display import HTML
from tqdm.auto import tqdm
from torchvision.models import inception_v3
from torch.cuda.amp import GradScaler, autocast
def is_cuda():
    """Report GPU availability and return the matching torch device string."""
    has_gpu = torch.cuda.is_available()
    if has_gpu:
        print("CUDA available")
        return "cuda"
    print("No CUDA. Working on CPU.")
    return "cpu"
device = is_cuda()

# --- Hyperparameters ---
root = "../input/tomjerrysc/"  # Kaggle dataset folder (presumably Tom & Jerry frames — verify)
batch_size = 8
image_size = 256  # training resolution (generator/discriminator assume 256x256)
nc = 3 # n channels
nz = 512 # n latent dim
ngf = 64 # size of generator feature map
ndf = 64 # size of discriminator feature map
lr = 0.0001
beta1 = 0.05  # NOTE(review): unusually low for Adam's beta1 (GAN papers typically use 0.5) — confirm
ngpu = 1
def show_tensor_images(image_tensor, num_images=8, size=(3, 64, 64), nrow=4, figsize=8):
    """Display up to `num_images` images from a [-1, 1]-normalized batch as one grid.

    NOTE: `size` is accepted for signature compatibility but is not used.
    """
    batch = image_tensor.detach().cpu()
    batch = (batch + 1) / 2  # undo Normalize((0.5,)*3, (0.5,)*3): map [-1, 1] -> [0, 1]
    grid = make_grid(batch[:num_images], nrow=nrow)
    plt.figure(figsize=(figsize, figsize))
    plt.imshow(grid.permute(1, 2, 0).squeeze())  # CHW -> HWC for matplotlib
    plt.show()
def to_rgb(img):
    """Return an RGB copy of a PIL image (pasting converts grayscale/palette modes)."""
    canvas = Image.new("RGB", img.size)
    canvas.paste(img)
    return canvas
class ImageSet(Dataset):
    """Flat-folder image dataset: every file matching root/*.* is one sample."""

    def __init__(self, root, transform):
        self.root = root
        self.transform = transform
        # Sorted so the ordering is deterministic across runs.
        self.imgs = sorted(glob.glob(os.path.join(root, "*.*")))

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, index):
        # Modulo lets out-of-range indices wrap instead of raising IndexError.
        path = self.imgs[index % len(self.imgs)]
        image = to_rgb(Image.open(path))
        return self.transform(image)
# Resize the shorter side to image_size, center-crop to a square, then scale
# to [-1, 1] (matches the generator's Tanh output range).
transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.CenterCrop(image_size),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
dataset = ImageSet(root=root, transform=transform)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
# Sanity check: display one batch of real images.
real_batch = next(iter(dataloader))
show_tensor_images(real_batch)
def weights_init(m):
    """DCGAN-style init: conv weights ~ N(0, 0.02); BatchNorm weight ~ N(1, 0.02), bias = 0."""
    name = type(m).__name__
    if "Conv" in name:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif "BatchNorm" in name:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)
# ImageNet-pretrained InceptionV3, used later to extract features for the FID score.
# NOTE(review): `pretrained=True` is deprecated in newer torchvision in favor of
# `weights=` — confirm against the installed version.
inception = inception_v3(pretrained=True) # For computation of FID score
class SLE(nn.Module):
    """Skip-Layer Excitation (FastGAN): gates a high-resolution feature map with
    per-channel weights squeezed from a low-resolution one.

    The squeeze path pools `low` to 4x4, reduces it to in_channel // 8 channels,
    and applies a Sigmoid, so the gate values lie in (0, 1).
    """

    def __init__(self, in_channel):
        super().__init__()
        squeeze = [
            nn.AdaptiveAvgPool2d(output_size=4),
            nn.Conv2d(in_channel, in_channel, 4, 1, 0),
            nn.LeakyReLU(0.1),
            nn.Conv2d(in_channel, in_channel // 8, 1, 1, 0),
            nn.Sigmoid(),
        ]
        self.block = nn.Sequential(*squeeze)

    def forward(self, high, low):
        gate = self.block(low)  # (N, in_channel // 8, 1, 1)
        return high * gate
def make_noise(n_samples=None, z_dim=256, device="cuda"):
    """Sample latent vectors ~ N(0, 1), shaped (n_samples, z_dim, 1, 1) for ConvTranspose input.

    n_samples defaults to the global batch_size, now resolved at call time
    rather than at definition time.

    Bug fix: the z_dim parameter was previously ignored (the latent size was
    hard-coded to 256); it is now honored. The default (256) matches the
    Generator's default z_dim, so existing callers are unaffected.
    """
    if n_samples is None:
        n_samples = batch_size
    noise = torch.randn(n_samples, z_dim, device=device)
    return noise[:, :, None, None]
class Generator(nn.Module):
    """FastGAN-style generator mapping latent (N, z_dim, 1, 1) -> RGB (N, 3, 256, 256).

    Two Skip-Layer-Excitation connections gate late high-resolution maps with
    early low-resolution features (h1: 2*z_dim ch @ 8x8, h2: z_dim ch @ 16x16).
    Shape comments below assume the default z_dim=256.
    """

    def __init__(self, z_dim=256, out_res=256):
        super().__init__()
        assert out_res == 256, "Only Output Resolution of 256x256 Implemented, got {}".format(out_res)
        # Stem: 1x1 -> 4x4 via transposed conv, then 4x4 -> 8x8 while doubling channels.
        stem = [
            nn.ConvTranspose2d(z_dim, z_dim, 4, 1, 0),
            nn.BatchNorm2d(z_dim),
            nn.ReLU(),
            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.Conv2d(z_dim, 2 * z_dim, 3, 1, 1),
            nn.BatchNorm2d(2 * z_dim),
            nn.ReLU(),
        ]
        self.block1 = nn.Sequential(*stem)
        # Each block doubles spatial resolution (upsample + 3x3 conv).
        self.block2 = self.make_block(2 * z_dim, z_dim)        # 512 -> 256, 16x16
        self.block3 = self.make_block(z_dim, z_dim // 2)       # 256 -> 128, 32x32
        self.block4 = self.make_block(z_dim // 2, z_dim // 4)  # 128 -> 64,  64x64
        self.block5 = self.make_block(z_dim // 4, z_dim // 4)  # 64 -> 64,   128x128
        self.block6 = self.make_block(z_dim // 4, z_dim // 8)  # 64 -> 32,   256x256
        self.out = nn.Sequential(
            nn.Conv2d(z_dim // 8, 3, 3, 1, 1),
            nn.Tanh(),
        )
        self.SLE1 = SLE(512)  # gates the 64ch/128x128 map from h1
        self.SLE2 = SLE(256)  # gates the 32ch/256x256 map from h2

    def make_block(self, in_channel, out_channel):
        """One upsampling stage: nearest-neighbor x2, 3x3 conv, BN, ReLU."""
        return nn.Sequential(
            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.Conv2d(in_channel, out_channel, 3, 1, 1),
            nn.BatchNorm2d(out_channel),
            nn.ReLU(),
        )

    def forward(self, x):
        h1 = self.block1(x)    # (N, 512, 8, 8) — SLE source 1
        h2 = self.block2(h1)   # (N, 256, 16, 16) — SLE source 2
        y = self.block3(h2)    # (N, 128, 32, 32)
        y = self.block4(y)     # (N, 64, 64, 64)
        y = self.block5(y)     # (N, 64, 128, 128)
        y = self.SLE1(y, h1)
        y = self.block6(y)     # (N, 32, 256, 256)
        y = self.SLE2(y, h2)
        return self.out(y)     # (N, 3, 256, 256), values in (-1, 1)
class Decoder(nn.Module):
    """Small upsampling decoder used by the discriminator's reconstruction branch:
    (N, in_feature, H, W) -> (N, 3, 16*H, 16*W).

    NOTE: make_block always takes its input channels from self.in_feature, so
    the three intermediate blocks map in_feature -> in_feature and only the
    final block reduces to 3 channels.
    """

    def __init__(self, in_feature=32):
        super().__init__()
        self.in_feature = in_feature
        layers = [self.make_block(in_feature) for _ in range(3)]
        layers.append(self.make_block(3))
        self.decoder = nn.Sequential(*layers)

    def make_block(self, out_feature):
        """Upsample x2, then 3x3 conv (from self.in_feature channels), BN, ReLU."""
        return nn.Sequential(
            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.Conv2d(self.in_feature, out_feature, 3, 1, 1),
            nn.BatchNorm2d(out_feature),
            nn.ReLU(),
        )

    def forward(self, x):
        return self.decoder(x)
class Discriminator(nn.Module):
    """FastGAN-style discriminator for 256x256 RGB images.

    Downsampling stages pair a strided conv block with an average-pool skip
    path; the head emits a (N, 1, 5, 5) patch of real/fake logits. When
    `recon` is enabled, two small decoders additionally reconstruct (a) an
    8x8 center crop of the intermediate 16x16 feature map and (b) the final
    8x8 feature map, for the self-supervised reconstruction loss on reals.
    Shape comments assume the default hidden_dim=64.

    Bug fix: `self.recon` was never initialized, so calling forward() before
    make_recon() raised AttributeError; it now defaults to True.
    """

    def __init__(self, hidden_dim=64, in_res=256):
        super().__init__()
        assert in_res == 256, "Only Output Resolution of 256x256 Implemented, got {}".format(in_res)
        self.recon = True  # reconstruction outputs enabled by default
        # 256 -> 64 spatial, 3 -> hidden_dim channels.
        self.block1 = nn.Sequential(
            nn.Conv2d(3, hidden_dim//2, 4, 2, 1),
            nn.LeakyReLU(0.1),
            nn.Conv2d(hidden_dim//2, hidden_dim, 4, 2, 1),
            nn.BatchNorm2d(hidden_dim),
            nn.LeakyReLU(0.1)
        )
        # Each stage halves resolution; the skip path is a cheap parallel route.
        self.block2 = self.make_block(hidden_dim, hidden_dim)
        self.skip2 = self.down_sample(hidden_dim, hidden_dim)
        self.block3 = self.make_block(hidden_dim, hidden_dim//2)
        self.skip3 = self.down_sample(hidden_dim, hidden_dim//2)
        self.block4 = self.make_block(hidden_dim//2, hidden_dim//2)
        self.skip4 = self.down_sample(hidden_dim//2, hidden_dim//2)
        # 8x8 feature map -> (N, 1, 5, 5) logits (final 4x4 valid conv).
        self.out = nn.Sequential(
            nn.Conv2d(hidden_dim//2, hidden_dim//4, 1, 1, 0),
            nn.BatchNorm2d(hidden_dim//4),
            nn.LeakyReLU(0.1),
            nn.Conv2d(hidden_dim//4, 1, 4, 1, 0)
        )
        self.decoder1 = Decoder()  # reconstructs from the cropped 8x8 part
        self.decoder2 = Decoder()  # reconstructs from the whole 8x8 map

    def make_recon(self, recon=True):
        """Toggle whether forward() also returns decoder reconstructions."""
        self.recon = recon

    def forward(self, x):
        y = self.block1(x)                      # (N, 64, 64, 64)
        y = self.block2(y) + self.skip2(y)      # (N, 64, 32, 32)
        h1 = self.block3(y) + self.skip3(y)     # (N, 32, 16, 16): kept for cropping
        h2 = self.block4(h1) + self.skip4(h1)   # (N, 32, 8, 8)
        # Simply center crop for now, where the literature implemented random crop.
        if h1.dim() == 4:
            part = h1[:, :, 4:12, 4:12]
        elif h1.dim() == 3:
            part = h1[:, 4:12, 4:12]
        else:
            # Bug fix: previously this branch only printed and continued with an
            # uncropped map, causing a confusing shape mismatch downstream.
            raise ValueError("invalid shape for feature map to be cropped, {}".format(h1.shape))
        y = self.out(h2)  # (N, 1, 5, 5) real/fake logits
        if self.recon:
            # (logits, reconstruction of center part, reconstruction of whole map)
            return y, self.decoder1(part), self.decoder2(h2)
        return y

    def make_block(self, in_channel, out_channel):
        """Strided 4x4 conv downsample followed by a 3x3 conv; BN + LeakyReLU."""
        return nn.Sequential(
            nn.Conv2d(in_channel, out_channel, 4, 2, 1),
            nn.BatchNorm2d(out_channel),
            nn.LeakyReLU(0.1),
            nn.Conv2d(out_channel, out_channel, 3, 1, 1),
            nn.BatchNorm2d(out_channel),
            nn.LeakyReLU(0.1)
        )

    def down_sample(self, in_channel, out_channel):
        """Skip path: 2x2 average pool then 1x1 conv; BN + LeakyReLU."""
        return nn.Sequential(
            nn.AvgPool2d(2, 2),
            nn.Conv2d(in_channel, out_channel, 1, 1, 0),
            nn.BatchNorm2d(out_channel),
            nn.LeakyReLU(0.1)
        )
def hinge_loss(output, real=True):
    """Hinge loss for the discriminator.

    real=True:  mean(max(0, 1 - output)) — pushes real scores above +1.
    real=False: mean(max(0, 1 + output)) — pushes fake scores below -1.
    (max(0, m) == -min(0, -m), so this matches the min-based formulation exactly.)
    """
    margin = 1.0 - output if real else 1.0 + output
    return torch.clamp(margin, min=0).mean()
def recon_loss(output, target):
    """Frobenius-norm distance between decoder reconstruction and target.

    torch.norm already reduces to a scalar, so the trailing mean is a no-op
    kept for parity with the original formulation.
    """
    diff = output - target
    return torch.mean(diff.norm())
def gen_loss(output):
    """Generator hinge loss: minimize -E[D(G(z))], i.e. raise D's score on fakes."""
    return output.mean().neg()
# Build and initialize the two networks.
G = Generator()
G.apply(weights_init)
G.to(device)
D = Discriminator()
D.apply(weights_init)
D.to(device)
# NOTE(review): these optimizers run with Adam defaults (lr=1e-3, betas=(0.9, 0.999));
# the `lr` and `beta1` hyperparameters defined at the top are never passed in —
# confirm whether that is intentional.
G_optim = optim.Adam(G.parameters())
D_optim = optim.Adam(D.parameters())
# One GradScaler per optimizer for mixed-precision training.
scaler1 = GradScaler()  # discriminator
scaler2 = GradScaler()  # generator
D_l, G_l = [], []  # per-iteration loss histories
imgs_list = []     # periodic snapshots of G(fixed_noise) for the progress animation
fixed_noise = make_noise()  # held fixed so snapshots are comparable across training
cur_iter = 0
num_iters = 15000
# Main training loop: one discriminator step and one generator step per batch,
# under mixed precision (autocast + per-optimizer GradScaler). The while loop
# repeats epochs until num_iters batch-iterations have run.
while cur_iter < num_iters:
    for real in tqdm(dataloader):
        real = real.to(device)

        # ---- Discriminator step ----
        D_optim.zero_grad()
        D.make_recon(True)  # the real pass also returns decoder reconstructions
        with autocast():
            D_real_pred, I_part, I_glob = D(real)
            D_real_loss = hinge_loss(D_real_pred)
            noise = make_noise()
            # NOTE(review): `fake` is not detached, so D's backward also backprops
            # through G (wasted compute; G's grads are zeroed before the G step).
            fake = G(noise)
            D.make_recon(False)  # fake pass: logits only
            D_fake_pred = D(fake)
            D_fake_loss = hinge_loss(D_fake_pred, real=False)
            # Reconstruction targets: center crop of the real image and a
            # half-resolution copy (both 128x128, matching the decoder outputs).
            if len(real.shape) == 4:
                real_part = real[:, :, 64:192, 64:192]
            elif len(real.shape) == 3:
                real_part = real[:, 64:192, 64:192]
            else:
                print("Invalid real shape, {}".format(real.shape))
            real_glob = F.interpolate(real, scale_factor=0.5)
            D_recon_loss = recon_loss(I_part, real_part) + recon_loss(I_glob, real_glob)
            D_loss = D_real_loss + D_fake_loss + D_recon_loss
        D_l.append(D_loss.item())
        scaler1.scale(D_loss).backward()
        scaler1.step(D_optim)
        scaler1.update()

        # ---- Generator step ----
        G_optim.zero_grad()
        with autocast():
            noise = make_noise()
            fake = G(noise)
            D_fake_pred = D(fake)  # recon is still False from the D step
            G_loss = gen_loss(D_fake_pred)
        G_l.append(G_loss.item())
        scaler2.scale(G_loss).backward()
        scaler2.step(G_optim)
        scaler2.update()

        cur_iter += 1
        # Every 1000 iterations: log losses, preview samples, checkpoint weights.
        if (cur_iter) % 1000 == 0:
            print("{} / {}, D_loss: {:.4f}, G_loss: {:.4f}".format(cur_iter, num_iters, D_loss.item(), G_loss.item()))
            noise = make_noise()
            fake = G(noise)
            show_tensor_images(fake)
            imgs_list.append(G(fixed_noise).detach().cpu())
            torch.save(G.state_dict(), "G.pt")
            torch.save(D.state_dict(), "D.pt")
        # Free large intermediates before the next batch to keep GPU memory low.
        del D_fake_pred, D_real_pred, I_part, I_glob, fake
        torch.cuda.empty_cache()
from torchvision.models import inception_v3
from torch.distributions import MultivariateNormal
import scipy
from scipy import linalg
# Strip the classifier head so inception(x) yields pooled features for FID.
inception.fc = nn.Identity()
inception.to(device)
# resnet = models.resnet50(pretrained=True)
def matrix_sqrt(x):
    """Principal matrix square root of a square matrix, computed via SciPy.

    sqrtm may return a complex array for nearly-singular inputs; only the real
    part is kept (standard practice in FID computation).

    Bug fix: the original returned `torch.Tensor(y.real, device=x.device)` —
    the legacy constructor does not support building from data on a chosen
    device; build the tensor first, then move it to x's device.
    """
    y = linalg.sqrtm(x.detach().cpu().numpy())
    return torch.tensor(y.real, dtype=x.dtype).to(x.device)
def frechet_distance(mu_x, mu_y, sig_x, sig_y):
    """Frechet (FID) distance between Gaussians N(mu_x, sig_x) and N(mu_y, sig_y):
    ||mu_x - mu_y||^2 + Tr(sig_x + sig_y - 2 * (sig_x sig_y)^(1/2)).
    """
    mean_term = torch.norm(mu_x - mu_y).pow(2)
    covmean = matrix_sqrt(torch.matmul(sig_x, sig_y))
    cov_term = torch.trace(sig_x + sig_y - 2 * covmean)
    return mean_term + cov_term
def preprocess(img):
    """Bilinearly resize an image batch to InceptionV3's expected 299x299 input."""
    resized = F.interpolate(img, size=(299, 299), mode='bilinear', align_corners=False)
    return resized
def get_cov(x):
    """Covariance matrix of a feature batch (rows = samples, columns = variables).

    Bug fix: move the tensor to CPU before calling .numpy() — the original
    raised a TypeError when x lived on a CUDA device.
    """
    features = x.detach().cpu().numpy()
    return torch.Tensor(np.cov(features, rowvar=False))
# fake_lst, real_lst = [], []
# G.eval()
# n_samples=10000
# batch_size=4
# dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
# cur_samples=0
# with torch.no_grad():
# for real_example in tqdm(dataloader, total=n_samples // batch_size): # Go by batch
# real_samples = preprocess(real_example)
# real_features = inception(real_samples.to(device)) # Move features to CPU
# real_lst.append(torch.Tensor(real_features[0].cpu()))
# fake_samples = make_noise()
# fake_samples = preprocess(G(fake_samples))
# fake_features = inception(fake_samples.to(device))
# fake_lst.append(torch.Tensor(fake_features[0].cpu()))
# cur_samples += len(real_samples)
# if cur_samples >= n_samples:
# break
# fake_features_all = torch.cat(fake_lst)
# real_features_all = torch.cat(real_lst)
# mu_fake = fake_features_all.mean(0)
# mu_real = real_features_all.mean(0)
# sigma_fake = get_cov(fake_features_all)
# sigma_real = get_cov(real_features_all)
# with torch.no_grad():
# print(frechet_distance(mu_real, mu_fake, sigma_real, sigma_fake).item())
def slerp(val, low, high):
    """Spherical linear interpolation between latent batches `low` and `high`.

    val is the interpolation fraction in [0, 1]; dim 1 is treated as the
    latent dimension when normalizing. val=0 returns low, val=1 returns high.
    """
    unit_low = low / torch.norm(low, dim=1, keepdim=True)
    unit_high = high / torch.norm(high, dim=1, keepdim=True)
    # Per-sample angle between the two latent directions.
    omega = torch.acos((unit_low * unit_high).sum(1))
    sin_omega = torch.sin(omega)
    w_low = (torch.sin((1.0 - val) * omega) / sin_omega).unsqueeze(1)
    w_high = (torch.sin(val * omega) / sin_omega).unsqueeze(1)
    return w_low * low + w_high * high
# Latent-space interpolation demo: 4 slerp paths between two noise batches,
# sampled at 9 steps (val = 0.1 .. 0.9), giving 36 latents in step-major order.
z1 = make_noise(n_samples=4)
z2 = make_noise(n_samples=4)
zs = torch.cat([slerp(v, z1, z2) for v in np.arange(0.1, 1, 0.1)])
# Reorder step-major -> path-major so each row of the 9-wide grid below shows
# one interpolation path from z1[k] toward z2[k].
zs = torch.cat([zs[4*k,:,:,:].unsqueeze(0) for k in range(9)]+[zs[4*k+1,:,:,:].unsqueeze(0) for k in range(9)]+
               [zs[4*k+2,:,:,:].unsqueeze(0) for k in range(9)]+[zs[4*k+3,:,:,:].unsqueeze(0) for k in range(9)])
show_tensor_images(G(zs), num_images=36, nrow=9, figsize=16)
# fig = plt.figure(figsize=(8,8))
# plt.axis("off")
# ims = [[plt.imshow(np.transpose(i,(1,2,0)), animated=True)] for i in imgs_list]
# ani = animation.ArtistAnimation(fig, ims, interval=1000, repeat_delay=1000, blit=True)
# HTML(ani.to_jshtml())